Spark.GBDT学习-GBTClassifier

用于分类的GBT(Gradient-Boosted Trees)算法,基于J.H. Friedman. "Stochastic Gradient Boosting"实现,目前不支持多分类任务。Gradient Boosting vs. TreeBoost:

  • 本实现基于Stochastic Gradient Boosting(随机梯度提升),而不是TreeBoost
  • 两种方法都是通过最小化损失函数,学习树的集成
  • TreeBoost方法相对于原始方法,基于损失函数对叶节点的输出进行了额外的修改
  • Spark考虑未来实现TreeBoost

GBTClassifier

定义

一个唯一标识uid,继承了Predictor类,继承了GBTClassifierParamsDefaultParamsWritableLogging特质。其中Predictor中的三个元素分别代表: 特征类型、学习器、学习到用于预测的模型

class GBTClassifier(override val uid: String) 
extends Predictor[Vector, GBTClassifier, GBTClassificationModel] 
with GBTClassifierParams with DefaultParamsWritable with Logging 
{
    def this() = this(Identifiable.randomUID("gbtc"))
    ...
}

参数

为了兼容JAVA API,覆盖了继承自特质(with trait)的参数setter方法。

  1. TreeClassifierParams参数
  • maxDepth
    树的最大深度,0意味着只有一个叶节点,1意味着有一个内部节点+两个叶节点。
    支持:>=0
    默认:5
  • maxBins
    用于离散连续特征的最大分桶数,用于每个节点特征分裂时分裂点的选择,分桶数越大意味着粒度越高。
    支持:>=2并且>=任一类别特征的分类数
    默认:32
  • minInstancesPerNode
    分裂后每个子节点含有的最小样本数,如果分裂后左孩子或右孩子含有的样本数少于该值,则该分裂无效。
    支持:>=1
    默认:1
  • minInfoGain
    树节点分裂时的最小信息增益。
    支持:>=0.0
    默认:0.0
  • maxMemoryInMB
    每次会对一组节点进行切分,分组是按照树的层次逐步进行。每组需要切分的节点个数视内存大小而定,如果内存太小,每次只能切分一个节点。单位MB
    默认:256MB
  • cacheNodeIds
    如果为true,算法会为每个实例缓存树节点ID;如果为false,算法会将树传递给执行器用于匹配实例和树节点。缓存有利于加速训练深度较大的树,用户可以通过参数checkpointInterval设置缓存被检查的频率或者不检查。
    默认:false
  • checkpointInterval
    表示缓存的树节点ID的检查频率,当cacheNodeIds为true并且检查目录(checkpoint directory)通过sparkContext设置过才有效。
    支持:>=1或者-1代表不检查,10意味着每10次迭代检查一次。
    默认:10
  • impurity
    用于计算信息增益的准则。不支持通过GBTClassifier.setImpurity方法设置该值。
    支持:entropy、gini
    默认:gini
  1. TreeEnsembleParams参数
  • subsamplingRate
    每一次迭代训练基学习器(决策树)时所使用的训练数据集的百分比。
    支持:(0, 1]
    默认:1.0
  • seed
    随机数种子
    默认:this.getClass.getName.hashCode.toLong
  1. GBTParams参数
  • maxIter
    最大迭代次数
    支持:>=0
    默认:20
  • stepSize
    学习率(learning rate/step size)参数,用于缩小(shrinking)每个基学习器的贡献。
    支持:(0, 1]
    默认:0.1
  1. GBTClassifierParams参数
  • lossType
    GBT最小化的损失函数,不区分大小写。
    支持:logistic
    默认:logistic

方法

  1. copy方法
    GBTClassifier的拷贝函数。
  2. train方法
    GBTClassifier类的主要方法,用于训练得到学习好的用于预测的模型。
// @input: 训练数据, DataSet
// @output: 学习到的模型, GBTClassificationModel
override protected def train(dataset: Dataset[_]):
GBTClassificationModel = {
    // 得到类别特征
    val categoricalFeatures: Map[Int, Int] =
    MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
    // 转换训练数据并进行验证
    // 将DataSet转换成RDD[LabeledPoint]
    // 只支持二分类,要求label为0或1
    val oldDataset: RDD[LabeledPoint] =
        dataset.select(col($(labelCol)), col($(featuresCol))).rdd.map {
            case Row(label: Double, features: Vector) =>
                require(label == 0 || label == 1, s"GBTClassifier was given dataset with invalid label $label.  Labels must be in {0,1}; note that GBTClassifier currently only supports binary classification.")
            LabeledPoint(label, features)
        }
    // 获得特征个数及boosting策略
    val numFeatures = oldDataset.first().features.size
    val boostingStrategy = super.getOldBoostingStrategy(categoricalFeatures, OldAlgo.Classification)
    // 用于记录日志
    val instr = Instrumentation.create(this, oldDataset)
    instr.logParams(params: _*)
    instr.logNumFeatures(numFeatures)
    instr.logNumClasses(2)
    // 调用GradientBoostedTrees训练得到一组学习器及其权重
    val (baseLearners, learnerWeights) = GradientBoostedTrees.run(oldDataset, boostingStrategy, $(seed))
    // 将学到的模型封装成GBTClassificationModel并返回
    val m = new GBTClassificationModel(uid, baseLearners, learnerWeights, numFeatures)
    instr.logSuccess(m)
    m
}

GBTClassifier对象

object GBTClassifier extends DefaultParamsReadable[GBTClassifier] {
    // final变量,访问支持的损失函数类型
    final val supportedLossTypes: Array[String] = GBTClassifierParams.supportedLossTypes
    // 从目录中加载GBTClassifier
    override def load(path: String): GBTClassifier = super.load(path)
}

GBTClassificationModel

用于分类的GBT模型,仅支持二分类,支持连续特征和类别特征。

定义

继承了PredictionModel类以及多个特质,其中PredictionModel的两个元素分别代表特征类型、学习到用于预测的模型

class GBTClassificationModel private[ml](
    override val uid: String,
    private val _trees: Array[DecisionTreeRegressionModel],
    private val _treeWeights: Array[Double],
    override val numFeatures: Int)
extends PredictionModel[Vector, GBTClassificationModel]
with GBTClassifierParams 
with TreeEnsembleModel[DecisionTreeRegressionModel]
with MLWritable with Serializable 
{
    // 检查_trees.nonEmpty
    // 检查_trees.length == _treeWeights.length
    val numTrees: Int = _trees.length
    ...
}

方法

  1. transformImpl方法
    首先将GBTClassificationModel进行广播,然后通过udf进行预测数据,udf中调用predict方法实现。
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
    // 广播本类
    val bcastModel = dataset.sparkSession.sparkContext.broadcast(this)
    val predictUDF = udf { (features: Any) =>
        // udf通过本类的predict方法实现
        bcastModel.value.predict(features.asInstanceOf[Vector])
    }
    // 使用udf将特征数据转换成预测数据
    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
  }
  1. predict方法
    关键的预测方法,先得到每个基学习器的预测值,然后进行融合得到最终的预测结果,最后得到类别结果。可以看到这里得到的预测值不是概率而是类别0/1,因为label被转换成了-1/+1,所以这里通过prediction>0.0得到预测lebel。
override protected def predict(features: Vector): Double = {
    // 得到每棵树的预测结果
    val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
    // 乘以权重之后求和得到融合结果
    val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
    // 得到预测lebel
    if (prediction > 0.0) 1.0 else 0.0
  }
  1. copy方法
    GBTClassificationModel的拷贝方法。
  2. toOld方法
    将ml的模型转换成mllib中老的API,ml域的私有方法。
private[ml] def toOld: OldGBTModel = {
    new OldGBTModel(OldAlgo.Classification, _trees.map(_.toOld), _treeWeights)
}
  1. write方法
    调用GBTClassificationModel对象的方法保存本模型。
override def write: MLWriter = new GBTClassificationModel.GBTClassificationModelWriter(this)

GBTClassificationModel对象

  1. fromOld方法
    从老的API中转换出当前模型
  2. GBTClassificationModelReader
    私有类,其中的load方法用于从目录中读取模型
  3. GBTClassificationModelWriter
    私有类,其中的saveImpl方法用于将本模型保存
  4. read方法
    新建GBTClassificationModelReader
  5. load方法
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 159,290评论 4 363
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 67,399评论 1 294
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 109,021评论 0 243
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 44,034评论 0 207
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 52,412评论 3 287
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 40,651评论 1 219
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 31,902评论 2 313
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 30,605评论 0 199
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 34,339评论 1 246
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 30,586评论 2 246
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 32,076评论 1 261
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 28,400评论 2 253
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 33,060评论 3 236
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 26,083评论 0 8
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 26,851评论 0 195
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 35,685评论 2 274
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 35,595评论 2 270

推荐阅读更多精彩内容